Skip to content

[Pallas] Use LONG_INT_TYPE for jagged offsets in examples and tests#2132

Draft
norx1991 wants to merge 1 commit into
mainfrom
yifeixu/jagged-long-int-type
Draft

[Pallas] Use LONG_INT_TYPE for jagged offsets in examples and tests#2132
norx1991 wants to merge 1 commit into
mainfrom
yifeixu/jagged-long-int-type

Conversation

@norx1991
Copy link
Copy Markdown
Contributor

Summary

Follow-up to #1950 (which introduced LONG_INT_TYPE and applied it to cross_entropy). Extends the pattern to all jagged examples and their tests so offset tensors are int32 on Pallas/TPU and int64 elsewhere.

torch.cumsum on int32 silently promotes to int64, so dtype= is also passed to cumsum to keep offsets in LONG_INT_TYPE.

This unblocks the int64 input rejection in Pallas for the jagged tests; remaining xfails now hit their originally-documented JAX tracer / BlockSpec errors instead of the int64 rejection.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 28, 2026
@norx1991 norx1991 force-pushed the yifeixu/jagged-long-int-type branch from b87bc3b to ea23837 Compare April 29, 2026 21:16
Extend the LONG_INT_TYPE pattern from #1950 (cross_entropy) to all jagged
examples and their tests. Jagged offset tensors are now int32 on Pallas/TPU
and int64 elsewhere.

torch.cumsum on int32 silently promotes to int64, so the dtype= kwarg is
also passed to cumsum to keep offsets in LONG_INT_TYPE.

This unblocks the int64 input rejection in Pallas (introduced in #1950)
for the jagged tests; remaining xfails now hit their originally-documented
JAX tracer / BlockSpec errors.
@norx1991 norx1991 force-pushed the yifeixu/jagged-long-int-type branch from ea23837 to 64ba3a9 Compare April 29, 2026 21:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant